{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "this notebook runs both cyto3 and cellpose-sam so uses cellpose==3.1.1.2 with the cellpose-sam net hacked in\n",
        "\n",
        "(see example env setup in benchmark_all_sam.ipynb)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bw_-7fg3kLRC"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from cellpose import io, metrics, models, utils, transforms, denoise\n",
        "import time\n",
        "from tqdm import trange\n",
        "from pathlib import Path\n",
        "from natsort import natsorted\n",
        "import tifffile\n",
        "import matplotlib.pyplot as plt\n",
        "import benchmarks\n",
        "\n",
        "files, imgs, masks_true = benchmarks.load_dataset(\"cyto2\")\n",
        "diam_true = [utils.diameters(m)[0] for m in masks_true]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yd4v47FOkLRD"
      },
      "source": [
        "## size"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TCs_QK0VkLRD"
      },
      "outputs": [],
      "source": [
        "#from cellSAM import cellsam_pipeline\n",
        "from cellSAM import segment_cellular_image\n",
        "from cellpose import resnet_torch\n",
        "import cv2\n",
        "import torch\n",
        "from train_subsets import TransformerMP\n",
        "io.logger_setup()\n",
        "\n",
        "device = torch.device(\"cuda\")\n",
        "\n",
        "ps = 8\n",
        "backbone = \"vit_l\"\n",
        "net = TransformerMP(ps=ps, backbone=backbone).to(device)\n",
        "net.load_model(\"models/cpsam8_0_2100_8_402175188\", strict=False, multigpu=False)\n",
        "\n",
        "model = models.CellposeModel(gpu=True, nchan=3)\n",
        "net.eval()\n",
        "model.net = net\n",
        "\n",
        "cp_model = models.CellposeModel(gpu=True, model_type=\"cyto3\")\n",
        "\n",
        "\n",
        "aps = [[], [], []]\n",
        "masks_preds = []\n",
        "for sz in [10, 15, 30, 60, 90]:\n",
        "    diameters = diam_true.copy() * (30. / sz)\n",
        "    imgs_rsz = [transforms.resize_image(imgs[i].transpose(1,2,0), rsz=30./diameters[i]).transpose(2,0,1) for i in range(len(imgs))]\n",
        "    masks_true_rsz = [transforms.resize_image(masks_true[i], rsz=30./diameters[i], no_channels=True, interpolation=cv2.INTER_NEAREST) for i in range(len(imgs))]\n",
        "\n",
        "    for j in range(1):\n",
        "        if j==0:\n",
        "            masks_pred, flows, styles = model.eval(imgs_rsz, diameter=30., channels=None, augment=False,\n",
        "                                            bsize=256, tile_overlap=0.1, batch_size=64,\n",
        "                                            flow_threshold=0.4, cellprob_threshold=0, normalize=False)\n",
        "        elif j==1:\n",
        "            masks_pred, flows, styles = cp_model.eval(imgs_rsz, diameter=30., channels=[2,3],\n",
        "                                            bsize=224, tile_overlap=0.5, batch_size=64, augment=True,\n",
        "                                            flow_threshold=0.4, cellprob_threshold=0, normalize=False)\n",
        "        else:\n",
        "            if sz!=90:\n",
        "                masks_pred = []\n",
        "                bsize = 1024\n",
        "                for i in trange(len(imgs_rsz)):\n",
        "                    img = imgs_rsz[i][[0,2,1]].copy()\n",
        "                    Ly, Lx = img.shape[1:]\n",
        "                    Lyr = bsize if Ly > bsize and Ly > Lx else Ly\n",
        "                    Lxr = bsize if Lx > bsize and Lx >= Ly else Lx\n",
        "                    Lxr = int(np.round(bsize * (Lx / Ly))) if Ly > Lx and Lyr==bsize else Lxr\n",
        "                    Lyr = int(np.round(bsize * (Ly / Lx))) if Lx >= Ly and Lxr==bsize else Lyr\n",
        "                    if Lyr != Ly or Lxr != Lx:\n",
        "                        img = cv2.resize(img.transpose(1, 2, 0), (Lxr, Lyr), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1)\n",
        "                    if Lyr < bsize or Lxr < bsize:\n",
        "                        padyx = [(int(np.floor((bsize-Lyr)/2)), int(np.ceil((bsize-Lyr)/2))),\n",
        "                                (int(np.floor((bsize-Lxr)/2)), int(np.ceil((bsize-Lxr)/2)))]\n",
        "                        img = np.pad(img, ((0,0), padyx[0], padyx[1]), mode='constant')\n",
        "                    else:\n",
        "                        padyx = [(0, 0), (0, 0)]\n",
        "                    try:\n",
        "                        masks, _, _ = segment_cellular_image(img, device='cuda')\n",
        "                    except:\n",
        "                        masks = np.zeros((bsize, bsize), dtype=\"uint16\")\n",
        "                    masks = masks[padyx[0][0]:bsize-padyx[0][1], padyx[1][0]:bsize-padyx[1][1]]\n",
        "                    if Lyr != Ly or Lxr != Lx:\n",
        "                        masks = cv2.resize(masks, (Lx, Ly), interpolation=cv2.INTER_NEAREST)\n",
        "                    masks_pred.append(masks)\n",
        "\n",
        "        ap, tp, fp, fn = metrics.average_precision(masks_true_rsz, masks_pred)\n",
        "        if sz==90 and j==2:\n",
        "            ap *= np.nan\n",
        "        else:\n",
        "            print(ap.mean(axis=0))\n",
        "        aps[j].append(ap)\n",
        "        if j==0:\n",
        "            masks_preds.append(masks_pred)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1dj4LpI7kLRE"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "aps = np.array(aps)\n",
        "plt.plot(aps[:,:,:,0].mean(axis=-1).T)\n",
        "ax = plt.gca()\n",
        "ax.set_xticks(np.arange(5))\n",
        "ax.set_xticklabels([\"10\", \"15\", \"30\", \"60\", \"90\"])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-Rgmk8OZkLRE"
      },
      "outputs": [],
      "source": [
        "np.save(\"size_invariance.npy\", {\"aps\": aps, \"masks_preds\": masks_preds})"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QWppfEl-kLRE"
      },
      "source": [
        "## noise"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vn2eSZHrkLRF"
      },
      "outputs": [],
      "source": [
        "from cellSAM import segment_cellular_image\n",
        "from cellpose import resnet_torch\n",
        "import cv2\n",
        "import torch\n",
        "io.logger_setup()\n",
        "\n",
        "device = torch.device(\"cuda\")\n",
        "\n",
        "ps = 8\n",
        "backbone = \"vit_l\"\n",
        "net = TransformerMP(ps=ps, backbone=backbone).to(device)\n",
        "net.load_model(\"models/cpsam8_0_2100_8_402175188\", strict=False, multigpu=False)\n",
        "\n",
        "model = models.CellposeModel(gpu=True, nchan=3)\n",
        "net.eval()\n",
        "model.net = net\n",
        "\n",
        "cp_model = models.CellposeModel(gpu=True, model_type=\"cyto3\")\n",
        "\n",
        "\n",
        "nstr = {\"poisson\": \"denoise\", \"blur\": \"deblur\", \"downsample\": \"upsample\", \"aniso\": \"aniso\"}\n",
        "for ii, noise_type in enumerate([\"poisson\", \"blur\", \"downsample\", \"aniso\"]):\n",
        "    #if noise_type != \"downsample\":\n",
        "    #    continue\n",
        "    masks_preds = [[], [], []]\n",
        "    dat = np.load(f\"{noise_type}_invariance.npy\", allow_pickle=True).item()\n",
        "    aps = dat[\"aps\"]\n",
        "    masks_preds = dat[\"masks_preds\"]\n",
        "    #aps = np.zeros((4, 3, len(imgs), 3))\n",
        "    mstr = \"cyto3\" if noise_type!=\"aniso\" else \"cyto2\"\n",
        "    dn_model = denoise.DenoiseModel(gpu=True, model_type=f\"{nstr[noise_type]}_{mstr}\", chan2=True)\n",
        "    print(noise_type)\n",
        "    if noise_type==\"poisson\":\n",
        "        param = np.array([5, 2.5, 0.5])\n",
        "    elif noise_type==\"blur\":\n",
        "        param = np.array([2, 4, 8])# 48])\n",
        "    elif noise_type==\"downsample\":\n",
        "        param = np.array([2, 5, 10])\n",
        "    elif noise_type==\"aniso\":\n",
        "        param = np.array([2, 6, 12])\n",
        "    print(param)\n",
        "    ap = np.zeros((len(imgs), len(param)))\n",
        "\n",
        "    denoise.deterministic()\n",
        "    importlib.reload(denoise)\n",
        "    for k in range(len(param)):\n",
        "        for i in trange(len(imgs)):\n",
        "            img = np.maximum(0, imgs[i].copy())\n",
        "            if noise_type==\"poisson\":\n",
        "                params = {\"poisson\": 1.0, \"blur\": 0.0, \"downsample\": 0.0, \"pscale\": param[k]}\n",
        "            elif noise_type==\"blur\":\n",
        "                params = {\"poisson\": 1.0, \"pscale\": 120., \"blur\": 1.0, \"downsample\": 0.0,\n",
        "                            \"sigma0\": param[k], \"sigma1\": param[k]}\n",
        "            elif noise_type==\"downsample\":\n",
        "                params = {\"poisson\": 0.0, \"pscale\": 0., \"blur\": 1.0, \"downsample\": 1.0, \"ds\": param[k],\n",
        "                            \"sigma0\": param[k]/2, \"sigma1\": param[k]/2}\n",
        "            else:\n",
        "                params = {\"poisson\": 0.0, \"pscale\": 0., \"blur\": 1.0, \"downsample\": 1.0, \"ds\": param[k],\n",
        "                            \"sigma0\": param[k]/2, \"sigma1\": param[k]/2*0, \"iso\": False}\n",
        "            img = denoise.add_noise(torch.from_numpy(img).unsqueeze(0),\n",
        "                                    **params).cpu().numpy().squeeze()\n",
        "\n",
        "            for j in range(1):\n",
        "                if j==0:\n",
        "                    masks_pred0, flows, styles = model.eval(img, diameter=30., channels=None, augment=False,\n",
        "                                                    bsize=256, tile_overlap=0.1, batch_size=64,\n",
        "                                                    flow_threshold=0.4, cellprob_threshold=0, normalize=False)\n",
        "                    masks_preds[k].append(masks_pred0)\n",
        "                elif j==1:\n",
        "                    masks_pred0, flows, styles = cp_model.eval(img, diameter=diam_true[i], channels=[2,3],\n",
        "                                                    bsize=224, tile_overlap=0.5, batch_size=64, augment=True,\n",
        "                                                    flow_threshold=0.4, cellprob_threshold=0, normalize=False)\n",
        "                elif j==2:\n",
        "                    if noise_type==\"downsample\" or noise_type==\"aniso\":\n",
        "                        img_rsz = transforms.resize_image(img.transpose(1,2,0).copy(), rsz=30./diam_true[i]).transpose(2,0,1)\n",
        "                    else:\n",
        "                        img_rsz = img.copy()\n",
        "                    img_dn = dn_model.eval(img_rsz[[1,2]], diameter=None, channels=[1,2], channel_axis=0)\n",
        "                    masks_pred0, flows, styles = cp_model.eval(img_dn, diameter=diam_true[i] if noise_type!=\"downsample\" and noise_type!=\"aniso\" else None,\n",
        "                                                               channels=[1,2],\n",
        "                                                    bsize=224, tile_overlap=0.5, batch_size=64, augment=True,\n",
        "                                                    flow_threshold=0.4, cellprob_threshold=0, normalize=False)\n",
        "                    masks_pred0 = transforms.resize_image(masks_pred0, Ly=img.shape[1], Lx=img.shape[2], no_channels=True, interpolation=cv2.INTER_NEAREST)\n",
        "                    #plt.imshow(masks_pred0)\n",
        "                    #plt.show()\n",
        "                else:\n",
        "                    bsize = 512\n",
        "                    img = img[[0,2,1]].copy()\n",
        "                    Ly, Lx = img.shape[1:]\n",
        "                    Lyr = bsize if Ly > bsize and Ly > Lx else Ly\n",
        "                    Lxr = bsize if Lx > bsize and Lx >= Ly else Lx\n",
        "                    Lxr = int(np.round(bsize * (Lx / Ly))) if Ly > Lx and Lyr==bsize else Lxr\n",
        "                    Lyr = int(np.round(bsize * (Ly / Lx))) if Lx >= Ly and Lxr==bsize else Lyr\n",
        "                    if Lyr != Ly or Lxr != Lx:\n",
        "                        img = cv2.resize(img.transpose(1, 2, 0), (Lxr, Lyr), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1)\n",
        "                    if Lyr < bsize or Lxr < bsize:\n",
        "                        padyx = [(int(np.floor((bsize-Lyr)/2)), int(np.ceil((bsize-Lyr)/2))),\n",
        "                                (int(np.floor((bsize-Lxr)/2)), int(np.ceil((bsize-Lxr)/2)))]\n",
        "                        img = np.pad(img, ((0,0), padyx[0], padyx[1]), mode='constant')\n",
        "                    else:\n",
        "                        padyx = [(0, 0), (0, 0)]\n",
        "                    try:\n",
        "                        masks, _, _ = segment_cellular_image(img, device='cuda')\n",
        "                    except:\n",
        "                        masks = np.zeros((bsize, bsize), dtype=\"uint16\")\n",
        "                    masks = masks[padyx[0][0]:bsize-padyx[0][1], padyx[1][0]:bsize-padyx[1][1]]\n",
        "                    if Lyr != Ly or Lxr != Lx:\n",
        "                           masks = cv2.resize(masks, (Lx, Ly), interpolation=cv2.INTER_NEAREST)\n",
        "                    masks_pred0 = masks\n",
        "                ap = metrics.average_precision([masks_true[i]], [masks_pred0])[0]\n",
        "                aps[j,k,i] = ap\n",
        "        print(aps[:,k,:,0].mean(axis=-1))\n",
        "\n",
        "    #np.save(f\"{noise_type}_invariance.npy\", {\"aps\": aps, \"masks_preds\": masks_preds})\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YqwceK15kLRG"
      },
      "source": [
        "## color"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "74XXCIfukLRG"
      },
      "outputs": [],
      "source": [
        "device = torch.device(\"cuda\")\n",
        "\n",
        "ps = 8\n",
        "backbone = \"vit_l\"\n",
        "net = TransformerMP(ps=ps, backbone=backbone).to(device)\n",
        "net.load_model(\"models/cpsam8_0_2100_8_402175188\", strict=False, multigpu=False)\n",
        "\n",
        "model = models.CellposeModel(gpu=True, nchan=3)\n",
        "net.eval()\n",
        "model.net = net\n",
        "\n",
        "irgb = [[0,1,2], [2,0,1], [1,2,0], 'random']\n",
        "rgb_title = ['RGB', 'BRG', 'GBR', 'Random \\n each']\n",
        "\n",
        "aps = []\n",
        "masks_preds = []\n",
        "test_data = []\n",
        "for i in range(4):\n",
        "    rgb = irgb[i]\n",
        "\n",
        "    np.random.seed(42)\n",
        "    test_data_copy = imgs.copy()\n",
        "\n",
        "    if rgb == 'random':\n",
        "        for j in range(len(test_data_copy)):\n",
        "            iswap = np.random.permutation(3)\n",
        "            test_data_copy[j] = test_data_copy[j][iswap]\n",
        "    elif rgb is not None:\n",
        "        for j in range(len(test_data_copy)):\n",
        "            test_data_copy[j] = test_data_copy[j][rgb]\n",
        "\n",
        "    masks_pred, flows, styles = model.eval(test_data_copy, normalize = False, tile_overlap = 0.1,\n",
        "                            bsize = 256,  diameter= None, #1. * diam_test[ind_im],\n",
        "                            augment = False, channels=None, niter = None, batch_size = 64)\n",
        "\n",
        "    app0 = metrics.average_precision(masks_true, masks_pred)[0]\n",
        "    aps.append(app0)\n",
        "    masks_preds.append(masks_pred)\n",
        "    test_data.append(test_data_copy)\n",
        "\n",
        "aps = np.array(aps)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W7JPFBD1kLRH"
      },
      "outputs": [],
      "source": [
        "np.save(f\"color_invariance.npy\", {\"aps\": aps, \"masks_preds\": masks_preds, \"test_data\": test_data})"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-NRkfuFQkLRH"
      },
      "outputs": [],
      "source": [
        "plt.plot(aps[:,:,0].mean(axis=-1))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "cellsam",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.15"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
